{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip uninstall --yes pyopenssl\n",
    "%pip install pyopenssl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Colab Setup and Imports { display-mode: \"form\" }\n",
    "# @markdown (double click to see the code)\n",
    "\n",
    "import os\n",
    "import random\n",
    "\n",
    "import git\n",
    "import numpy as np\n",
    "from gym import spaces\n",
    "\n",
    "%matplotlib inline\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "repo = git.Repo(\".\", search_parent_directories=True)\n",
    "dir_path = repo.working_tree_dir\n",
    "data_path = os.path.join(dir_path, \"data\")\n",
    "os.chdir(dir_path)\n",
    "\n",
    "from PIL import Image\n",
    "\n",
    "import habitat\n",
    "from habitat.core.logging import logger\n",
    "from habitat.core.registry import registry\n",
    "from habitat.sims.habitat_simulator.actions import HabitatSimActions\n",
    "from habitat.tasks.nav.nav import NavigationTask\n",
    "from habitat_baselines.common.baseline_registry import baseline_registry\n",
    "from habitat_baselines.config.default import get_config as get_baselines_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -m habitat_sim.utils.datasets_download --uids mp3d_example_scene --data-path {data_path} --no-replace"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @title Define Observation Display Utility Function { display-mode: \"form\" }\n",
    "\n",
    "# @markdown A convenient function that displays sensor observations with matplotlib.\n",
    "\n",
    "# @markdown (double click to see the code)\n",
    "\n",
    "\n",
    "# Change to do something like this maybe: https://stackoverflow.com/a/41432704\n",
    "def display_sample(rgb_obs, semantic_obs=None, depth_obs=None):\n",
    "    \"\"\"Show the RGB observation and, if given, semantic and depth panels.\n",
    "\n",
    "    :param rgb_obs: image array handed to ``Image.fromarray`` with mode \"RGB\".\n",
    "    :param semantic_obs: optional array of per-pixel semantic ids; ids are\n",
    "        wrapped modulo 40 onto the ``d3_40_colors_rgb`` palette. ``None`` or\n",
    "        an empty array skips this panel.\n",
    "    :param depth_obs: optional depth array, scaled by ``255 / 10`` (so 10 maps\n",
    "        to white) and clipped to ``[0, 255]``. A trailing size-1 channel axis\n",
    "        is dropped. ``None`` or an empty array skips this panel.\n",
    "    \"\"\"\n",
    "    from habitat_sim.utils.common import d3_40_colors_rgb\n",
    "\n",
    "    rgb_img = Image.fromarray(rgb_obs, mode=\"RGB\")\n",
    "\n",
    "    arr = [rgb_img]\n",
    "    titles = [\"rgb\"]\n",
    "    if semantic_obs is not None and semantic_obs.size != 0:\n",
    "        semantic_img = Image.new(\n",
    "            \"P\", (semantic_obs.shape[1], semantic_obs.shape[0])\n",
    "        )\n",
    "        semantic_img.putpalette(d3_40_colors_rgb.flatten())\n",
    "        semantic_img.putdata((semantic_obs.flatten() % 40).astype(np.uint8))\n",
    "        semantic_img = semantic_img.convert(\"RGBA\")\n",
    "        arr.append(semantic_img)\n",
    "        titles.append(\"semantic\")\n",
    "\n",
    "    if depth_obs is not None and depth_obs.size != 0:\n",
    "        if depth_obs.ndim == 3 and depth_obs.shape[-1] == 1:\n",
    "            # Drop a trailing size-1 channel so the array is 2-D for mode \"L\".\n",
    "            depth_obs = depth_obs[..., 0]\n",
    "        # Clip before the uint8 cast: out-of-range floats would otherwise\n",
    "        # wrap or saturate unpredictably instead of clamping to white.\n",
    "        depth_img = Image.fromarray(\n",
    "            np.clip(depth_obs / 10 * 255, 0, 255).astype(np.uint8), mode=\"L\"\n",
    "        )\n",
    "        arr.append(depth_img)\n",
    "        titles.append(\"depth\")\n",
    "\n",
    "    plt.figure(figsize=(12, 8))\n",
    "    for i, data in enumerate(arr):\n",
    "        ax = plt.subplot(1, 3, i + 1)\n",
    "        ax.axis(\"off\")\n",
    "        ax.set_title(titles[i])\n",
    "        plt.imshow(data)\n",
    "    plt.show(block=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setup PointNav Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    config = habitat.get_config(\n",
    "        config_path=os.path.join(\n",
    "            dir_path,\n",
    "            \"habitat-lab/habitat/config/benchmark/nav/pointnav/pointnav_habitat_test.yaml\",\n",
    "        ),\n",
    "        overrides=[\n",
    "            \"habitat.environment.max_episode_steps=10\",\n",
    "            \"habitat.environment.iterator_options.shuffle=False\",\n",
    "        ],\n",
    "    )\n",
    "\n",
    "    try:\n",
    "        env.close()  # type: ignore[has-type]\n",
    "    except NameError:\n",
    "        pass\n",
    "    env = habitat.Env(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    action = None\n",
    "    obs = env.reset()\n",
    "    valid_actions = [\"turn_left\", \"turn_right\", \"move_forward\", \"stop\"]\n",
    "    interactive_control = False  # @param {type:\"boolean\"}\n",
    "    # Also leave the loop when the episode ends on its own (e.g. once\n",
    "    # max_episode_steps is reached): habitat.Env refuses to step a finished\n",
    "    # episode.\n",
    "    while action != \"stop\" and not env.episode_over:\n",
    "        display_sample(obs[\"rgb\"])\n",
    "        print(\n",
    "            \"distance to goal: {:.2f}\".format(\n",
    "                obs[\"pointgoal_with_gps_compass\"][0]\n",
    "            )\n",
    "        )\n",
    "        print(\n",
    "            \"angle to goal (radians): {:.2f}\".format(\n",
    "                obs[\"pointgoal_with_gps_compass\"][1]\n",
    "            )\n",
    "        )\n",
    "        if interactive_control:\n",
    "            action = input(\n",
    "                \"enter action out of {}:\\n\".format(\", \".join(valid_actions))\n",
    "            )\n",
    "            assert (\n",
    "                action in valid_actions\n",
    "            ), \"invalid action {} entered, choose one amongst {}\".format(\n",
    "                action, \", \".join(valid_actions)\n",
    "            )\n",
    "        else:\n",
    "            # Scripted run: pop() takes from the end, so \"stop\" comes first.\n",
    "            action = valid_actions.pop()\n",
    "        obs = env.step(\n",
    "            {\n",
    "                \"action\": action,\n",
    "            }\n",
    "        )\n",
    "\n",
    "    env.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    print(env.get_metrics())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## RL Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    config = get_baselines_config(\"pointnav/ppo_pointnav_example.yaml\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set random seeds\n",
    "if __name__ == \"__main__\":\n",
    "    seed = \"42\"  # @param {type:\"string\"}\n",
    "    # Assigned as-is to habitat_baselines.total_num_steps (no x1000 scaling).\n",
    "    total_num_steps = \"10\"  # @param {type:\"string\"}\n",
    "\n",
    "    with habitat.config.read_write(config):\n",
    "        config.habitat.seed = int(seed)\n",
    "        config.habitat_baselines.total_num_steps = int(total_num_steps)\n",
    "        config.habitat_baselines.log_interval = 1\n",
    "\n",
    "    random.seed(config.habitat.seed)\n",
    "    np.random.seed(config.habitat.seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    trainer_init = baseline_registry.get_trainer(\n",
    "        config.habitat_baselines.trainer_name\n",
    "    )\n",
    "    # Guard against get_trainer returning None (e.g. an unregistered\n",
    "    # trainer_name) so the failure message names the trainer.\n",
    "    assert (\n",
    "        trainer_init is not None\n",
    "    ), f\"{config.habitat_baselines.trainer_name} is not supported\"\n",
    "    trainer = trainer_init(config)\n",
    "    trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 0
   },
   "outputs": [],
   "source": [
    "# @markdown (double click to see the code)\n",
    "\n",
    "# example tensorboard visualization\n",
    "# for more details refer to [link](https://github.com/facebookresearch/habitat-lab/tree/main/habitat-baselines/habitat_baselines#additional-utilities).\n",
    "\n",
    "try:\n",
    "    from IPython import display\n",
    "\n",
    "    with open(\"./res/img/tensorboard_video_demo.gif\", \"rb\") as f:\n",
    "        # The file is a .gif, so label the bytes as GIF rather than PNG.\n",
    "        display.display(display.Image(data=f.read(), format=\"gif\"))\n",
    "except ImportError:\n",
    "    # IPython is not installed: skip the illustration.\n",
    "    pass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Key Concepts\n",
    "\n",
    "All the concepts link to their definitions:\n",
    "\n",
    "1. [`habitat.sims.habitat_simulator.HabitatSim`](https://github.com/facebookresearch/habitat-lab/blob/main/habitat-lab/habitat/sims/habitat_simulator/habitat_simulator.py#L254)\n",
    "Thin wrapper over `habitat_sim` providing seamless integration with the experimentation framework.\n",
    "\n",
    "\n",
    "2. [`habitat.core.env.Env`](https://github.com/facebookresearch/habitat-lab/blob/main/habitat-lab/habitat/core/env.py#L26)\n",
    "Abstraction for the universe of agent, task and simulator. Agents that you train and evaluate operate inside the environment.\n",
    "\n",
    "\n",
    "3. [`habitat.core.env.RLEnv`](https://github.com/facebookresearch/habitat-lab/blob/main/habitat-lab/habitat/core/env.py#L347)\n",
    "Extends the `Env` class for reinforcement learning by defining the reward and other required components.\n",
    "\n",
    "\n",
    "4. [`habitat.core.embodied_task.EmbodiedTask`](https://github.com/facebookresearch/habitat-lab/blob/main/habitat-lab/habitat/core/embodied_task.py#L201)\n",
    "Defines the task that the agent needs to solve. This class holds the definition of the observation space, action space, measures, and simulator usage. Examples: PointNav, ObjectNav.\n",
    "\n",
    "\n",
    "5. [`habitat.core.dataset.Dataset`](https://github.com/facebookresearch/habitat-lab/blob/main/habitat-lab/habitat/core/dataset.py#L107)\n",
    "Wrapper over the information required for an embodied task's dataset; contains the definition of, and interaction with, an `episode`.\n",
    "\n",
    "\n",
    "6. [`habitat.core.embodied_task.Measure`](https://github.com/facebookresearch/habitat-lab/blob/main/habitat-lab/habitat/core/embodied_task.py#L80)\n",
    "Defines the metrics for an embodied task, e.g. [SPL](https://github.com/facebookresearch/habitat-lab/blob/main/habitat-lab/habitat/tasks/nav/nav.py#L565).\n",
    "\n",
    "\n",
    "7. [`habitat_baselines`](https://github.com/facebookresearch/habitat-lab/tree/main/habitat-baselines/habitat_baselines)\n",
    "RL, SLAM, and heuristic baseline implementations for the different embodied tasks."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a new Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    config = habitat.get_config(\n",
    "        config_path=os.path.join(\n",
    "            dir_path,\n",
    "            \"habitat-lab/habitat/config/benchmark/nav/pointnav/pointnav_habitat_test.yaml\",\n",
    "        ),\n",
    "        overrides=[\n",
    "            \"habitat.environment.max_episode_steps=10\",\n",
    "            \"habitat.environment.iterator_options.shuffle=False\",\n",
    "        ],\n",
    "    )\n",
    "\n",
    "\n",
    "@registry.register_task(name=\"TestNav-v0\")\n",
    "class NewNavigationTask(NavigationTask):\n",
    "    \"\"\"NavigationTask that logs the agent state on every activity check and\n",
    "    uses a custom rule for deciding whether an episode is still active.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config, sim, dataset):\n",
    "        logger.info(\"Creating a new type of task\")\n",
    "        super().__init__(config=config, sim=sim, dataset=dataset)\n",
    "\n",
    "    def _check_episode_is_active(self, *args, **kwargs):\n",
    "        logger.info(\n",
    "            \"Current agent position: {}\".format(self._sim.get_agent_state())\n",
    "        )\n",
    "        collision = self._sim.previous_step_collided\n",
    "        # True until the agent issues the stop action.\n",
    "        stop_not_called = not getattr(self, \"is_stop_called\", False)\n",
    "        # Active after a collision or while stop has not been called; the task\n",
    "        # only reports the episode as over once stop was called and the last\n",
    "        # step did not collide.\n",
    "        return collision or stop_not_called\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    with habitat.config.read_write(config):\n",
    "        config.habitat.task.type = \"TestNav-v0\"\n",
    "\n",
    "    try:\n",
    "        env.close()  # type: ignore[has-type]\n",
    "    except NameError:\n",
    "        pass\n",
    "    env = habitat.Env(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "    action = None\n",
    "    # Keep the reset observation: otherwise `obs` would be whatever an\n",
    "    # earlier cell left behind (or undefined on a fresh kernel).\n",
    "    obs = env.reset()\n",
    "    valid_actions = [\"turn_left\", \"turn_right\", \"move_forward\", \"stop\"]\n",
    "    interactive_control = False  # @param {type:\"boolean\"}\n",
    "    while not env.episode_over:\n",
    "        display_sample(obs[\"rgb\"])\n",
    "        if interactive_control:\n",
    "            action = input(\n",
    "                \"enter action out of {}:\\n\".format(\", \".join(valid_actions))\n",
    "            )\n",
    "            assert (\n",
    "                action in valid_actions\n",
    "            ), \"invalid action {} entered, choose one amongst {}\".format(\n",
    "                action, \", \".join(valid_actions)\n",
    "            )\n",
    "        else:\n",
    "            # Scripted run: pop() takes from the end, so \"stop\" comes first.\n",
    "            action = valid_actions.pop()\n",
    "        obs = env.step(\n",
    "            {\n",
    "                \"action\": action,\n",
    "                \"action_args\": None,\n",
    "            }\n",
    "        )\n",
    "        print(\"Episode over:\", env.episode_over)\n",
    "\n",
    "    env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "source": [
    "## Create a new Sensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@registry.register_sensor(name=\"agent_position_sensor\")\n",
    "class AgentPositionSensor(habitat.Sensor):\n",
    "    \"\"\"Lab sensor exposing the agent's current position as a float32 3-vector.\"\"\"\n",
    "\n",
    "    def __init__(self, sim, config, **kwargs):\n",
    "        super().__init__(config=config)\n",
    "        self._sim = sim\n",
    "\n",
    "    # Defines the name of the sensor in the sensor suite dictionary\n",
    "    def _get_uuid(self, *args, **kwargs):\n",
    "        return \"agent_position\"\n",
    "\n",
    "    # Defines the type of the sensor\n",
    "    def _get_sensor_type(self, *args, **kwargs):\n",
    "        return habitat.SensorTypes.POSITION\n",
    "\n",
    "    # Defines the size and range of the observations of the sensor\n",
    "    def _get_observation_space(self, *args, **kwargs):\n",
    "        return spaces.Box(\n",
    "            low=np.finfo(np.float32).min,\n",
    "            high=np.finfo(np.float32).max,\n",
    "            shape=(3,),\n",
    "            dtype=np.float32,\n",
    "        )\n",
    "\n",
    "    # This is called whenever reset is called or an action is taken\n",
    "    def get_observation(self, observations, *args, episode, **kwargs):\n",
    "        position = self._sim.get_agent_state().position\n",
    "        # Cast so the observation matches the float32 Box declared above.\n",
    "        return np.asarray(position, dtype=np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == \"__main__\":\n",
    "    config = habitat.get_config(\n",
    "        config_path=os.path.join(\n",
    "            dir_path,\n",
    "            \"habitat-lab/habitat/config/benchmark/nav/pointnav/pointnav_habitat_test.yaml\",\n",
    "        ),\n",
    "        overrides=[\n",
    "            \"habitat.environment.max_episode_steps=10\",\n",
    "            \"habitat.environment.iterator_options.shuffle=False\",\n",
    "        ],\n",
    "    )\n",
    "\n",
    "    from habitat.config.default_structured_configs import LabSensorConfig\n",
    "\n",
    "    # We use the base sensor config, but you could also define your own\n",
    "    # AgentPositionSensorConfig that inherits from LabSensorConfig\n",
    "\n",
    "    with habitat.config.read_write(config):\n",
    "        # Now define the config for the sensor\n",
    "        config.habitat.task.lab_sensors[\n",
    "            \"agent_position_sensor\"\n",
    "        ] = LabSensorConfig(type=\"agent_position_sensor\")\n",
    "\n",
    "    try:\n",
    "        env.close()\n",
    "    except NameError:\n",
    "        pass\n",
    "    env = habitat.Env(config=config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    obs = env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    obs.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    print(obs[\"agent_position\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a new Agent"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "lines_to_next_cell": 2
   },
   "outputs": [],
   "source": [
    "# An example agent which can be submitted to habitat-challenge.\n",
    "# To participate and for more details refer to:\n",
    "# - https://aihabitat.org/challenge/2020/\n",
    "# - https://github.com/facebookresearch/habitat-challenge\n",
    "\n",
    "\n",
    "class ForwardOnlyAgent(habitat.Agent):\n",
    "    \"\"\"Agent that always moves forward until it is close enough to the goal.\n",
    "\n",
    "    :param success_distance: the agent stops once the goal distance is at or\n",
    "        below this value.\n",
    "    :param goal_sensor_uuid: observation key whose first element is read as\n",
    "        the distance to the goal.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, success_distance, goal_sensor_uuid):\n",
    "        self.goal_sensor_uuid = goal_sensor_uuid\n",
    "        self.dist_threshold_to_stop = success_distance\n",
    "\n",
    "    def reset(self):\n",
    "        # No per-episode state to clear.\n",
    "        pass\n",
    "\n",
    "    def is_goal_reached(self, observations):\n",
    "        distance_to_goal = observations[self.goal_sensor_uuid][0]\n",
    "        return self.dist_threshold_to_stop >= distance_to_goal\n",
    "\n",
    "    def act(self, observations):\n",
    "        action = (\n",
    "            HabitatSimActions.stop\n",
    "            if self.is_goal_reached(observations)\n",
    "            else HabitatSimActions.move_forward\n",
    "        )\n",
    "        return {\"action\": action}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Other Examples\n",
    "\n",
    "[Create a new action space](https://github.com/facebookresearch/habitat-lab/blob/main/examples/new_actions.py)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Habitat Lab",
   "provenance": []
  },
  "jupytext": {
   "cell_metadata_filter": "-all",
   "formats": "nb_python//py:percent,notebooks//ipynb",
   "notebook_metadata_filter": "all"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
